package main

import (
	"bufio"
	"context"
	"flag"
	"fmt"
	"io"
	"os"
	"os/exec"
	"strings"
	"time"

	"github.com/fsnotify/fsnotify"
	"github.com/google/shlex"
	"github.com/pachyderm/pachyderm/v2/src/internal/errors"
	"github.com/pachyderm/pachyderm/v2/src/internal/log"
	"github.com/pachyderm/pachyderm/v2/src/internal/pctx"
	"go.uber.org/zap"
	"golang.org/x/exp/slices"
	"golang.org/x/sync/errgroup"
)

// main initializes logging, parses the command-line flags, splits the extra
// gotestsum / `go test` argument strings into words using shlex, and then hands
// everything to run. Failures at any stage are reported through log.Exit.
func main() {
	log.InitPachctlLogger()
	ctx := pctx.Background("test-runner")

	var (
		tags             = flag.String("tags", "", "Tags to run, for example k8s. Tests without this flag will not be selected.")
		fileName         = flag.String("file", "tests_to_run.csv", "csv file containing the list of test names to run generated by the collector.")
		gotestsumArgsRaw = flag.String("gotestsum-args", "", "Additional arguments to pass to the gotestsum portion of the test command.")
		gotestArgsRaw    = flag.String("gotest-args", "", "Additional arguments to pass to the 'go test' portion of the test command.")
		shard            = flag.Int("shard", 0, "0 indexed current runner index that we are on.")
		totalShards      = flag.Int("total-shards", 1, "Total number of runners that we are sharding over.")
		threadPool       = flag.Int("threads", 1, "Number of tests to execute concurrently.")
	)
	flag.Parse()

	gotestsumArgs, splitErr := shlex.Split(*gotestsumArgsRaw)
	if splitErr != nil {
		log.Exit(ctx, "Error parsing gotestsumArgs", zap.Error(splitErr))
	}
	gotestArgs, splitErr := shlex.Split(*gotestArgsRaw)
	if splitErr != nil {
		log.Exit(ctx, "Error parsing gotestArgs", zap.Error(splitErr))
	}

	if runErr := run(ctx, *tags, *fileName, gotestsumArgs, gotestArgs, *shard, *totalShards, *threadPool); runErr != nil {
		log.Exit(ctx, "Error running tests", zap.Error(runErr))
	}
	os.Exit(0)
}

func run(ctx context.Context, tags string, fileName string, gotestsumArgs []string, gotestArgs []string, shard int, totalShards int, threadPool int) error {
	tests, err := readTests(ctx, fileName)
	if err != nil {
		return errors.Wrapf(err, "reading file %v", fileName)
	}
	slices.Sort(tests) // sort so we shard them from a pre-determined order

	// loop through by the number of shards so that each gets a roughly equal number on each
	testsForShard := map[string][]string{}
	for idx := shard; idx < len(tests); idx += totalShards {
		val := strings.Split(tests[idx], ",")
		if len(val) < 1 {
			return errors.Errorf("error parsing test name and package to run. Value: %v", tests[idx])
		}
		pkg := val[0]
		testName := val[1]
		// index all tests by package as we collect the ones for this shard. This lets
		// us run all tests in each package on this shard with one `go test` command, preserving the serial
		// running of tests without t.parallel the same way that go test ./... would since
		// got test also runs packages in paralllel.
		if _, ok := testsForShard[pkg]; !ok {
			testsForShard[pkg] = []string{testName}
		} else {
			testsForShard[pkg] = append(testsForShard[pkg], testName)
		}
	}

	eg, _ := errgroup.WithContext(ctx)
	eg.SetLimit(threadPool)
	for pkg, tests := range testsForShard {
		threadLocalPkg := pkg
		threadLocalTests := tests
		eg.Go(func() error {
			return errors.EnsureStack(runTest(threadLocalPkg, threadLocalTests, tags, gotestsumArgs, gotestArgs))
		})
	}
	err = eg.Wait()

	if err != nil {
		return errors.EnsureStack(err)
	}
	return nil
}

// readTests returns the lines of fileName with surrounding whitespace trimmed
// and blank lines dropped. If the file does not exist yet, it first blocks in
// waitForTestListFile until the file appears (or that call fails).
//
// ctx is currently unused; it is kept so the signature stays stable.
func readTests(ctx context.Context, fileName string) (_ []string, retErr error) {
	// Only a missing file is worth waiting for. Other Stat errors (e.g.
	// permissions) are reported immediately by os.Open below instead of
	// after a long wait.
	if _, err := os.Stat(fileName); os.IsNotExist(err) {
		if err := waitForTestListFile(fileName); err != nil {
			return nil, err
		}
	}
	file, err := os.Open(fileName)
	if err != nil {
		return nil, errors.EnsureStack(err)
	}
	defer errors.Close(&retErr, file, "close input %v", fileName)

	var tests []string
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		// Trimming keeps stray "\r" or spaces out of package/test names.
		// Blank lines carry no "package,TestName" entry, so they are skipped.
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}
		tests = append(tests, line)
	}
	if err := scanner.Err(); err != nil {
		return nil, errors.EnsureStack(err)
	}
	return tests, nil
}

// waitForTestListFile blocks until fileName exists, giving up after 9 minutes.
// It watches the current working directory (".") with fsnotify, so fileName is
// expected to name an entry directly inside it, e.g. "tests_to_run.csv" or
// "./tests_to_run.csv".
//
// NOTE(review): this returns as soon as the file exists, which may be before
// its writer has finished writing it. Confirm that the producer creates the
// file atomically (e.g. write-then-rename).
func waitForTestListFile(fileName string) (retErr error) {
	const timeout = 9 * time.Minute
	exists := func() bool {
		_, err := os.Stat(fileName)
		return err == nil
	}

	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		return errors.EnsureStack(err)
	}
	defer errors.Close(&retErr, watcher, "close fsnotify watcher")
	if err := watcher.Add("."); err != nil {
		return errors.EnsureStack(err)
	}
	// The file may have been created after the caller's check but before the
	// watch was registered. No event would ever arrive for it, so look again.
	if exists() {
		return nil
	}

	timer := time.NewTimer(timeout)
	defer timer.Stop()
	for {
		select {
		case _, ok := <-watcher.Events:
			if !ok {
				return errors.Errorf("Error waiting for test list file")
			}
			// Ask the filesystem directly rather than comparing event.Name,
			// whose exact form (e.g. a leading "./") may vary with how the
			// watch path is joined to the entry name.
			if exists() {
				return nil
			}
		case watchErr, ok := <-watcher.Errors:
			if !ok {
				return errors.Errorf("fsnotify error channel closed while waiting for test list file")
			}
			// Errors may mean events were dropped (e.g. a queue overflow).
			// Keep waiting, but re-check in case the creation event was lost.
			fmt.Printf("warning: error watching for test list file %v: %v\n", fileName, watchErr)
			if exists() {
				return nil
			}
		case <-timer.C:
			return errors.Errorf("Timed out waiting for test list file")
		}
	}
}

// runTest runs testNames from the single package pkg through gotestsum, which
// wraps `go test -json` with --rerun-fails. We run one package at a time so
// tests with the same name in different packages, like TestConfig or
// TestDebug, don't run multiple times if they land on separate shards.
//
// The TEST_RESULTS environment variable must be set. Per-package JUnit and
// JSON reports are written beneath it. The command's combined stdout/stderr is
// buffered and printed to stdout only after the command finishes. A non-zero
// exit is returned as an error naming pkg.
func runTest(pkg string, testNames []string, tags string, gotestsumArgs []string, gotestArgs []string) error {
	resultsFolder := os.Getenv("TEST_RESULTS")
	if resultsFolder == "" {
		return errors.Errorf("TEST_RESULTS environment variable must be set.")
	}
	// Filesystem-friendly package name for the per-package report files.
	pkgShort := strings.ReplaceAll(strings.TrimPrefix(pkg, "github.com/pachyderm/pachyderm/v2/"), "/", "-")
	runTestArgs := []string{
		"--raw-command",
		fmt.Sprintf("--packages=%s", pkg),
		"--rerun-fails",
		"--rerun-fails-max-failures=1",
		"--debug",
		fmt.Sprintf("--junitfile=%s/circle/gotestsum-report-%s.xml", resultsFolder, pkgShort),
		fmt.Sprintf("--jsonfile=%s/%s-go-test-results.jsonl", resultsFolder, pkgShort),
	}
	runTestArgs = append(runTestArgs, gotestsumArgs...)

	// Anchor each name so -run selects exactly these tests and not others
	// that merely share a prefix (e.g. TestFoo must not match TestFooBar).
	patterns := make([]string, 0, len(testNames))
	for _, test := range testNames {
		patterns = append(patterns, "^"+test+"$")
	}
	runTestArgs = append(runTestArgs, "--", "go", "test", pkg, "-json",
		fmt.Sprintf("-tags=%s", tags),
		fmt.Sprintf("-run=%s", strings.Join(patterns, "|")),
	)
	runTestArgs = append(runTestArgs, gotestArgs...)

	cmd := exec.Command("gotestsum", runTestArgs...)
	// exec does not go through a shell, so env values must not be quoted:
	// quotes would become part of the directory path.
	cmd.Env = append(os.Environ(), "CGO_ENABLED=0", "GOCOVERDIR=/tmp/test-results/")
	fmt.Printf("Running command %v\n", cmd.String())
	testsOutput, err := cmd.CombinedOutput()
	_, copyErr := io.Copy(os.Stdout, strings.NewReader(string(testsOutput)))
	if err != nil {
		return errors.Wrapf(err, "running tests for package %v", pkg)
	}
	return errors.EnsureStack(copyErr)
}
